{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from gensim.models.word2vec import Word2Vec\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import jieba\n",
    "import re\n",
    "from sklearn.externals import joblib # 结构化数据以二进制形式丢出来\n",
    "from sklearn.svm import SVC\n",
    "import sys\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"载入数据，清理数据，做分词处理，切分训练集与测试集\"\"\"\n",
    "# 正则清洗句子中的各种符号\n",
    "def regular_clean(sen):\n",
    "    sen = re.sub('…{2,100}', '…', sen)\n",
    "    sen = re.sub(',{2,100}', ',', sen)\n",
    "    sen = re.sub('，{2,100}', ',', sen)\n",
    "    sen = re.sub('\\.{3,100}', '...', sen)\n",
    "    sen = re.sub('。{2,100}', '。', sen)\n",
    "    sen = re.sub('\\?{3,100}', '?', sen)\n",
    "    sen = re.sub('？{2,100}', '？', sen)\n",
    "    sen = re.sub('!{2,100}', '!', sen)\n",
    "    sen = re.sub('！{2,100}', '!', sen)\n",
    "    sen = re.sub('、{2,100}', '、', sen)\n",
    "    sen = re.sub('-{2,100}', '-', sen)\n",
    "    sen = re.sub('[ ]{2,100}', ' ', sen)\n",
    "    return sen\n",
    "\n",
    "# 创建停用词列表\n",
    "def stopwordslist():\n",
    "    stopwords_file = os.path.join('.','datas','chinesestopwords.txt')\n",
    "    stopwords = [line.strip() for line in open(stopwords_file,encoding='UTF-8').readlines()]\n",
    "    return stopwords\n",
    "\n",
    "# 返回分词结果\n",
    "def cutwords(sten):\n",
    "    sten = regular_clean(sten)\n",
    "    tokens = jieba.cut(sten,cut_all=True)\n",
    "    stopwords = stopwordslist()\n",
    "    stopwords.extend([' ','\\t','\\n'])\n",
    "    return [x for x in tokens if x not in stopwords]\n",
    "#     return [x for x in tokens]\n",
    "\n",
    "def load_file_and_preprocessing():\n",
    "    neg_words,pos_words = [],[]\n",
    "    datafile = os.path.join('.','datas','goods_zh.tsv')\n",
    "    with open(datafile,encoding=\"utf-8\") as f:\n",
    "        for line in tqdm(f):\n",
    "            parts = line.split(\",,\")\n",
    "            if parts[-1].strip() == \"0\":\n",
    "                neg_words.append(cutwords(parts[0].strip()))\n",
    "            elif parts[-1].strip() == \"1\":\n",
    "                pos_words.append(cutwords(parts[0].strip()))\n",
    "            else:\n",
    "                continue\n",
    "    return pos_words,neg_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:00,  9.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "加载文件中...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "101058it [04:45, 353.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train, x_test, y_train, y_test保存完成!\n"
     ]
    }
   ],
   "source": [
    "print(\"加载文件中...\")\n",
    "pos_words,neg_words = load_file_and_preprocessing()\n",
    "tags = np.concatenate((np.ones(len(pos_words)),np.zeros(len(neg_words))))\n",
    "x_train, x_test, y_train, y_test = train_test_split(np.concatenate((pos_words,neg_words)),\n",
    "                                                    tags,\n",
    "                                                    test_size=0.2)\n",
    "np.save('./datas/svm_data/x_train.npy',x_train) # x就是数据，y就是标签\n",
    "np.save('./datas/svm_data/x_test.npy',x_test)\n",
    "np.save('./datas/svm_data/y_train.npy',y_train)\n",
    "np.save('./datas/svm_data/y_test.npy',y_test)\n",
    "print(\"x_train, x_test, y_train, y_test保存完成!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "睡衣/收到/男朋友/朋友/喜欢/款式/简单/大方/面料/软/舒服/睡衣/他家/买 1.0\n",
      "鞋子/不错/味道/不错/透气/透气性/气性/棒/晒/几张/分享 1.0\n",
      "没/发/耳机/评价/耳机/客服/回应 0.0\n",
      "说/懒得/退换/星星/配送/小哥/态度/服务 0.0\n",
      "质量 0.0\n",
      "喜欢/图片/一模一样 1.0\n",
      "黑色/手机/配/白色/充电/充电器/电器/耳机/转换/绝/! 0.0\n",
      "用户/未填写/填写/评价/价内/内容 1.0\n",
      "不好/降价/保价/太/气人 0.0\n",
      "老/顾客/这款/店家/搞活/活动/拍下/划算 1.0\n"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    print(\"/\".join(x_train[i]),y_train[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"对每个句子的所有词向量取均值，生成一个句子的vector\"\"\"\n",
    "def build_sentence_vector(text,size,imdb_w2v):\n",
    "    vec = np.zeros(size).reshape((1,size))\n",
    "    count = 0\n",
    "    for word in text:\n",
    "        try:\n",
    "            vec += imdb_w2v[word].reshape((1,size))\n",
    "            count += 1\n",
    "        except KeyError:\n",
    "            continue\n",
    "    if count != 0:\n",
    "        vec /= count \n",
    "    return vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\24544\\.conda\\envs\\chatbot\\lib\\site-packages\\ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).\n",
      "  import sys\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[train_vecs.shape]: (80846, 300)\n",
      "[test_vecs.shape]: (20212, 300)\n"
     ]
    }
   ],
   "source": [
    "\"\"\"计算词向量\"\"\"\n",
    "def get_train_vecs(x,x_train,x_test):\n",
    "    n_dim = 300\n",
    "    # 初始化模型和词表\n",
    "    imdb_w2v = Word2Vec(size=n_dim,min_count=10,seed=1)\n",
    "    imdb_w2v.build_vocab(x)# 建模\n",
    "    imdb_w2v.save(\"./datas/svm_data/w2v_model.pkl\")\n",
    "    # 在训练集上\n",
    "    imdb_w2v.train(x_train,total_examples=imdb_w2v.corpus_count, epochs=50)\n",
    "    train_vecs = np.concatenate([build_sentence_vector(z,n_dim,imdb_w2v) for z in x_train])\n",
    "#     train_vecs = scale(train_vecs)\n",
    "    np.save(\"./datas/svm_data/train_vecs.npy\",train_vecs)\n",
    "    print(\"[train_vecs.shape]:\",train_vecs.shape)\n",
    "    \n",
    "    # 在测试集上训练\n",
    "    imdb_w2v.train(x_test,total_examples=imdb_w2v.corpus_count, epochs=50)\n",
    "    test_vecs = np.concatenate([build_sentence_vector(z,n_dim,imdb_w2v) for z in x_test])\n",
    "#     test_vecs = scale(test_vecs)\n",
    "    np.save(\"./datas/svm_data/test_vecs.npy\",test_vecs)\n",
    "    print(\"[test_vecs.shape]:\",test_vecs.shape)\n",
    "get_train_vecs(np.concatenate((pos_words, neg_words)),x_train,x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LibSVM]0.9309321195329507\n"
     ]
    }
   ],
   "source": [
    "# 加载保存的文件数据\n",
    "def get_data():\n",
    "    train_vecs = np.load('./datas/svm_data/train_vecs.npy')\n",
    "    y_train = np.load('./datas/svm_data/y_train.npy')\n",
    "    test_vecs = np.load('./datas/svm_data/test_vecs.npy')\n",
    "    y_test = np.load('./datas/svm_data/y_test.npy')\n",
    "    return train_vecs,y_train,test_vecs,y_test\n",
    "\n",
    "# 训练svm模型train svm model with sklearn\n",
    "def svm_train(train_vecs, y_train, test_vecs, y_test):\n",
    "    clf = SVC(kernel='rbf', verbose=True)\n",
    "    clf.fit(train_vecs, y_train)\n",
    "    joblib.dump(clf, './datas/svm_data/model.pkl')\n",
    "    print(clf.score(test_vecs, y_test))\n",
    "train_vecs,y_train,test_vecs,y_test = get_data()\n",
    "svm_train(train_vecs,y_train,test_vecs,y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use svm model to predict...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\24544\\.conda\\envs\\chatbot\\lib\\site-packages\\ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).\n",
      "  import sys\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.] 1.0\n",
      "[1.] 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\24544\\.conda\\envs\\chatbot\\lib\\site-packages\\ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).\n",
      "  import sys\n"
     ]
    }
   ],
   "source": [
    "# 预测句子\n",
    "# load word2vec and smv model and use them to predict\n",
    "print(\"use svm model to predict...\")\n",
    "def svm_predict(str):\n",
    "    clf = joblib.load('./datas/svm_data/model.pkl')\n",
    "    model = Word2Vec.load('./datas/svm_data/w2v_model.pkl')\n",
    "    str_sege = cutwords(str)\n",
    "    words = [x for x in str_sege]\n",
    "    n_dim = 300\n",
    "    words_vecs = build_sentence_vector(words, n_dim, model)\n",
    "    result = clf.predict(words_vecs)\n",
    "    print(result,result[0])\n",
    "pre_str = \"手机刚买来设个密码，结果手机出了问题，输正确的密码却打不开，过了两三个小时才恢复正常\"\n",
    "svm_predict(pre_str)\n",
    "pre_str = \"一颗星都是给多了，简直是麻烦死了！！！！！必须要激活，不然不能使用，不是试用就是购买！！！！！！！！！！！\"\n",
    "svm_predict(pre_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
